Lifecycle for predictive models

Day 1: Predictive modeling

Get prepared

install.packages(c("tableone", "DALEX", "ggplot2", "partykit", "mlr3", "mlr3learners", "ranger", "mlr3tuning", "paradox"))

Part 1: Introduction to predictive modeling + EDA

The purpose of this tutorial is to present a life cycle of a single predictive model for a problem related to binary classification (but I deliberately don’t mention logistic regression). Along the way, we will tackle various interesting topics such as model training, model verification, hyperparameter tuning, and exploratory model analysis.

The examples presented here are inspired by three books: The Elements of Statistical Learning, which covers the mathematical foundations; The mlr3 book, which presents the mlr3 software package designed for advanced model programming; and Explanatory Model Analysis, which overviews methods for model exploration and visualisation. Note that responsible modelling requires knowledge that cannot be learned in two days.
So, after this introduction, I highly recommend checking out these books.

Why should I care?

Predictive models have been used throughout human history. Priests in Egypt predicted when the flood of the Nile or a solar eclipse would come. Developments in statistics, the increasing availability of datasets, and growing computing power allow predictive models to be built faster and faster.

Today, predictive models are used virtually everywhere. Planning the supply chain for a large corporation, recommending lunch or a movie for the evening, or predicting traffic jams in a city. Newspapers are full of interesting applications.

But how are such predictive models developed? In the following sections, we will go through the life cycle of a predictive model: from the concept phase, through design, training, and validation, to deployment. For this example, we will use a data set on the risk of death for Covid-19 patients after SARS-COV-2 infection. But keep in mind that the data presented here is artificial. It was generated to mirror relations present in real data, but it does not contain real observations of real patients. Still, it should be an interesting use case to discuss a typical lifetime of a predictive model.

Tools

These materials are based on two R packages: mlr3 for model training and DALEX for model visualization and explanation. There are more packages with similar functionalities: for modelling, other popular choices are mlr, tidymodels and caret, while for model explanation you will find lots of interesting features in flashlight and iml.

The problem

The life cycle of a predictive model begins with a well-defined problem. In this example, we are looking for a model that assesses the risk of death after a Covid-19 diagnosis. We don't want to guess who will survive and who won't. We want to construct a score that allows us to sort patients by risk of death.

Why do we need such a model? It could have many applications! Those at higher risk of death could be given more protection, such as providing them with pulse oximeters or preferentially vaccinating them.

Load packages

library("tableone")
library("DALEX")
library("ggplot2")
library("partykit")
library("mlr3")
library("mlr3learners")
library("ranger")
library("mlr3tuning")
library("paradox")

set.seed(1313)

Conception

Before we build any model, even before we touch any data we should first determine for what purpose we will build a predictive model.

It is very important to define the objective before we sit down to programming, because later it is easy to get lost in setting function parameters and dealing with all the necessary details, and to lose sight of the long-term goal.

So, first: Define the objective.

For the purpose of these exercises, I have selected data on the covid pandemic. Imagine that we want to determine the order of vaccination. In this example we want to create a predictive model that assesses individual risks because we would like to rank patients according to their risk.

To get a model that gives the best ranking, we will use the AUC measure to evaluate model performance. What exactly the AUC is, I will explain a little later; right now the key thing is that we are interested in the ranking of patients based on their risk scores.
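Since we only care about ranking, it is worth knowing that the AUC has a direct ranking interpretation: it is the probability that a randomly chosen positive case receives a higher score than a randomly chosen negative one. A minimal rank-based sketch of this computation (the helper auc_rank is ours, not part of any package mentioned here):

```r
# AUC as concordance: P(score of a random positive > score of a random negative),
# computed from ranks (equivalent to the Mann-Whitney U statistic)
auc_rank <- function(score, y) {   # y: logical vector, TRUE = positive class
  r <- rank(score)                 # average ranks handle ties with 0.5 credit
  n_pos <- sum(y)
  n_neg <- sum(!y)
  (sum(r[y]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}

auc_rank(score = c(1, 3, 2, 4), y = c(FALSE, FALSE, TRUE, TRUE))
## [1] 0.75   # 3 of the 4 positive-negative pairs are ordered correctly
```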

Read the data

To build a model we need good data. In Machine Learning, the word good means a large amount of representative data. Collecting representative data is not easy and often requires designing an appropriate experiment.

The best possible scenario is that one can design and run an experiment to collect the necessary data. In less comfortable situations, we look for "natural experiments," i.e., data that have been collected for another purpose but that can be used to build a model. Here we will use data collected through epidemiological interviews. There will be a lot of data points, and it should be fairly representative, although unfortunately it only includes symptomatic patients who tested positive for SARS-COV-2.

For the purposes of this exercise, I have prepared two sets of characteristics of patients infected with covid. It is important to note that these are not real patient data. This is simulated data, generated to have relationships consistent with real data (obtained from NIH), but the data itself is not real. Fortunately, they are sufficient for the purposes of our exercise.

The data is divided into two sets, covid_spring and covid_summer. The first was acquired in spring 2020 and will be used as training data, while the second was acquired in summer and will be used for validation. In machine learning, model validation is performed on a separate data set. This controls the risk of overfitting an elastic model to the data. If we do not have a separate set, then one is generated using cross-validation, out-of-sample or out-of-time techniques.

  • covid_spring.csv corresponds to covid mortality data from spring 2020. We will use this data for model training.
  • covid_summer.csv corresponds to covid mortality data from summer 2020. We will use this data for model validation.
covid_spring <- read.table("covid_spring.csv", sep =";", header = TRUE, stringsAsFactors = TRUE)
covid_summer <- read.table("covid_summer.csv", sep =";", header = TRUE, stringsAsFactors = TRUE)

Explore the data

Before we start any serious modeling, it is worth looking at the data first. To do this, we will do a simple EDA. In R there are many tools for data exploration; I value packages that support the so-called "table one".

library("tableone")

table1 <- CreateTableOne(vars = colnames(covid_spring)[1:11],
                         data = covid_spring,
                         strata = "Death")
print(table1)
##                                    Stratified by Death
##                                     No            Yes           p      test
##   n                                  9487           513                    
##   Gender = Male (%)                  4554 (48.0)    271 (52.8)   0.037     
##   Age (mean (SD))                   44.19 (18.32) 74.44 (13.27) <0.001     
##   Cardiovascular.Diseases = Yes (%)   839 ( 8.8)    273 (53.2)  <0.001     
##   Diabetes = Yes (%)                  260 ( 2.7)     78 (15.2)  <0.001     
##   Neurological.Diseases = Yes (%)     127 ( 1.3)     57 (11.1)  <0.001     
##   Kidney.Diseases = Yes (%)           111 ( 1.2)     62 (12.1)  <0.001     
##   Cancer = Yes (%)                    158 ( 1.7)     68 (13.3)  <0.001     
##   Hospitalization = Yes (%)          2344 (24.7)    481 (93.8)  <0.001     
##   Fever = Yes (%)                    3314 (34.9)    335 (65.3)  <0.001     
##   Cough = Yes (%)                    3062 (32.3)    253 (49.3)  <0.001     
##   Weakness = Yes (%)                 2282 (24.1)    196 (38.2)  <0.001

During modeling, exploration often takes the most time. In this case, we will limit ourselves to some simple graphs.

ggplot(covid_spring, aes(Age)) +
  geom_histogram() +
  ggtitle("Histogram of age")

ggplot(covid_spring, aes(Age, fill = Death)) +
  geom_histogram() +
  ggtitle("Histogram of age")

ggplot(covid_spring, aes(Age, fill = Death)) +
  geom_histogram(position = "fill") +
  ggtitle("Histogram of age")

library("pheatmap")
pheatmap((covid_spring[,3:11] == "Yes") + 0)

Transform the data

One of the most important rules to remember when building a predictive model is: Do not condition on the future!

Variables like Hospitalization or Cough are not good predictors, because they are not known in advance.

covid_spring <- covid_spring[,c("Gender", "Age", "Cardiovascular.Diseases", "Diabetes",
               "Neurological.Diseases", "Kidney.Diseases", "Cancer",
               "Death")]
covid_summer <- covid_summer[,c("Gender", "Age", "Cardiovascular.Diseases", "Diabetes",
               "Neurological.Diseases", "Kidney.Diseases", "Cancer",
               "Death")]

Your turn

  • Plot Age distribution for covid_spring and covid_summer.
  • Calculate tableone for covid_spring and covid_summer.
  • (extra) In the DALEX package you will find titanic_imputed dataset. Calculate tableone for this dataset.

Part 2: Hello model! First predictive model + How to measure performance

We will think of a predictive model as a function that computes a certain prediction for certain input data. Usually, such a function is built automatically based on the data. But technically the model can be any function defined in any way. The first model will be based on mortality statistics for different age groups collected by the CDC (the Centers for Disease Control and Prevention; you will find a set of useful statistics related to Covid mortality on their page).

In many cases, you do not need data to create a model. Just google some information about the problem.

It turns out that CDC has some decent statistics about age-related mortality. These statistics will suffice as a first approximation of our model.

https://www.cdc.gov/coronavirus/2019-ncov/covid-data/investigations-discovery/hospitalization-death-by-age.html

Lesson 1: Often you don’t need individual data to build a good model.

Create a model

What is a predictive model? We will think of it as a function that takes a set of numbers as input and returns a single number as the result - the score.

cdc_risk_ind <- function(x, base_risk = 0.00003) {
  if (x$Age < 4.5) return(2 * base_risk)
  if (x$Age < 17.5) return(1 * base_risk)
  if (x$Age < 29.5) return(15 * base_risk)
  if (x$Age < 39.5) return(45 * base_risk)
  if (x$Age < 49.5) return(130 * base_risk)
  if (x$Age < 64.5) return(400 * base_risk)
  if (x$Age < 74.5) return(1100 * base_risk)
  if (x$Age < 84.5) return(2800 * base_risk)
  7900 * base_risk
}
x <- data.frame(Age = 25)
cdc_risk_ind(x)
## [1] 0.00045

The same function can be written in a slightly more compact, vectorized form (now it works on many rows):

cdc_risk <- function(x, base_risk = 0.00003) {
  bin <- cut(x$Age, c(-Inf, 4.5, 17.5, 29.5, 39.5, 49.5, 64.5, 74.5, 84.5, Inf))
  relative_risk <- c(2, 1, 15, 45, 130, 400, 1100, 2800, 7900)[as.numeric(bin)] 
  relative_risk * base_risk
}

# check it
x <- data.frame(Age = c(25,45,85))
cdc_risk(x)
## [1] 0.00045 0.00390 0.23700
summary(cdc_risk(covid_spring))
##    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
## 0.00003 0.00135 0.00390 0.01849 0.01200 0.23700
table(Death = covid_spring$Death, 
      Prediction.above.005 = cdc_risk(covid_spring) > 0.05)
##      Prediction.above.005
## Death FALSE TRUE
##   No   8946  541
##   Yes   237  276

Wrap the model

In R, we have many tools for creating models. The problem is that these tools are created by different people and return results in different structures. So, in order to work with models uniformly, we need to wrap each model so that it exposes a uniform interface.

Different models have different APIs.

But you need One API to Rule Them All!

The DALEX library provides a unified architecture to explore and validate models using different analytical methods.

More info

library("DALEX")
model_cdc <-  DALEX::explain(cdc_risk,
                   predict_function = function(m, x) m(x),
                   data  = covid_summer,
                   y     = covid_summer$Death == "Yes",
                   type  = "classification",
                   label = "CDC")
## Preparation of a new explainer is initiated
##   -> model label       :  CDC 
##   -> data              :  10000  rows  8  cols 
##   -> target variable   :  10000  values 
##   -> predict function  :  function(m, x) m(x) 
##   -> predicted values  :  No value for predict function target column. (  default  )
##   -> model_info        :  package Model of class: function package unrecognized , ver. Unknown , task regression (  default  ) 
##   -> model_info        :  type set to  classification 
##   -> model_info        :  Model info detected classification task but 'y' is a logical . Converted to numeric.  (  NOTE  )
##   -> predicted values  :  numerical, min =  3e-05 , mean =  0.01480215 , max =  0.237  
##   -> residual function :  difference between y and yhat (  default  )
##   -> residuals         :  numerical, min =  -0.237 , mean =  0.008497847 , max =  0.99955  
##   A new explainer has been created! 
predict(model_cdc, x)
## [1] 0.00045 0.00390 0.23700

Model performance

The evaluation of model performance for classification is based on different measures than for regression.

For regression, commonly used measures are the mean squared error (MSE)

\[MSE(f) = \frac{1}{n} \sum_{i=1}^{n} (f(x_i) - y_i)^2 \]

and the root mean squared error (RMSE)

\[RMSE(f) = \sqrt{MSE(f)} \]

For classification, commonly used measures are Accuracy

\[ACC(f) = (TP + TN)/n\]

Precision

\[Prec(f) = TP/(TP + FP)\]

and Recall

\[Recall(f) = TP/(TP + FN)\]

and F1 score

\[F1(f) = 2\frac{Prec(f) * Recall(f) }{Prec(f) + Recall(f)}\]

In this problem we are interested in ranking of scores, so we will use the AUC measure (the area under the ROC curve).

There are many measures for evaluating predictive models and they are located in various R packages (ROCR, measures, mlr3measures, etc.). For simplicity, in this example, we use only the AUC measure from the DALEX package.
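To make the definitions above concrete, here is a minimal sketch that computes accuracy, precision, recall and F1 from raw scores at a fixed cutoff (the helper name classif_measures is ours, not from any of the packages used in this tutorial):

```r
# confusion-matrix based measures for a binary classifier at a given cutoff
classif_measures <- function(score, y, cutoff = 0.5) {  # y: logical, TRUE = positive
  pred <- score > cutoff
  TP <- sum(pred & y);   FP <- sum(pred & !y)
  TN <- sum(!pred & !y); FN <- sum(!pred & y)
  acc  <- (TP + TN) / length(y)   # ACC  = (TP + TN) / n
  prec <- TP / (TP + FP)          # Prec = TP / (TP + FP)
  rec  <- TP / (TP + FN)          # Recall = TP / (TP + FN)
  c(accuracy = acc, precision = prec, recall = rec,
    f1 = 2 * prec * rec / (prec + rec))
}

classif_measures(score = c(0.9, 0.8, 0.3, 0.1), y = c(TRUE, FALSE, TRUE, FALSE))
##  accuracy precision    recall        f1
##       0.5       0.5       0.5       0.5
```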

Pregnancy: Sensitivity and Specificity

http://getthediagnosis.org/diagnosis/Pregnancy.htm

https://en.wikipedia.org/wiki/Sensitivity_and_specificity

For AUC the cutoff does not matter. But we set it to get nice precision and F1.

More info

Model performance

Model exploration starts with an assessment of how good the model is. The DALEX::model_performance function calculates a set of the most common measures for the specified model.

mp_cdc <- model_performance(model_cdc, cutoff = 0.1)
mp_cdc
## Measures for:  classification
## recall     : 0.2188841 
## precision  : 0.2602041 
## f1         : 0.2377622 
## accuracy   : 0.9673 
## auc        : 0.904654
## 
## Residuals:
##       0%      10%      20%      30%      40%      50%      60%      70% 
## -0.23700 -0.03300 -0.01200 -0.01200 -0.00390 -0.00390 -0.00135 -0.00135 
##      80%      90%     100% 
## -0.00045 -0.00006  0.99955

ROC

Note: The model is evaluated on the data given in the explainer. Use DALEX::update_data() to specify another dataset.

Note: Explainer knows whether the model is for classification or regression, so it automatically selects the right measures. It can be overridden if needed.

The S3 generic plot function draws a graphical summary of the model performance. With the geom argument, one can determine the type of chart.

More info

plot(mp_cdc, geom = "roc")

LIFT

More info

plot(mp_cdc, geom = "lift")

Your turn

  • Calculate the AUC for the CDC model on the covid_spring data.
  • Plot ROC for both covid_spring and covid_summer data.
  • (extra) In the DALEX package you will find titanic_imputed dataset. Build a similar model to CDC but for the Titanic dataset. How good is your model?

Part 3: Basics of decision tree and random forest

Usually, we don’t know which function is best for our problem. This is why we want to use data to find/train such a function with some automated algorithm.

In Machine Learning, there are hundreds of algorithms available. Usually, this training boils down to finding parameters for some family of models. One of the most popular families of models is decision trees. Their great advantage is the transparency of their structure.

We will begin building the model by constructing a decision tree, controlling the complexity of the model step by step.

More info

library("partykit")

tree1 <- ctree(Death ~., covid_spring, 
              control = ctree_control(maxdepth = 1))
plot(tree1)

tree2 <- ctree(Death ~., covid_spring, 
              control = ctree_control(maxdepth = 2))
plot(tree2)

tree3 <- ctree(Death ~., covid_spring, 
              control = ctree_control(maxdepth = 3))
plot(tree3)

tree <- ctree(Death ~., covid_spring, 
              control = ctree_control(alpha = 0.0001))
plot(tree)

To work with different models uniformly, we will also wrap this one into an explainer.

model_tree <-  DALEX::explain(tree,
                   predict_function = function(m, x) predict(m, x, type = "prob")[,2],
                   data = covid_summer,
                   y = covid_summer$Death == "Yes",
                   type = "classification",
                   label = "Tree",
                   verbose = FALSE)

Test your model

mp_tree <- model_performance(model_tree, cutoff = 0.1)
mp_tree
## Measures for:  classification
## recall     : 0.8626609 
## precision  : 0.1492205 
## f1         : 0.2544304 
## accuracy   : 0.8822 
## auc        : 0.9136169
## 
## Residuals:
##           0%          10%          20%          30%          40%          50% 
## -0.462686567 -0.199029126 -0.007781621 -0.007781621 -0.007781621 -0.007781621 
##          60%          70%          80%          90%         100% 
## -0.007781621 -0.007781621 -0.007781621 -0.007781621  0.992218379
plot(mp_tree, geom = "roc")

plot(mp_tree, mp_cdc, geom = "roc")

Your turn

  • Check the AUC for the Tree model on the covid_spring data.
  • Plot ROC for both covid_spring and covid_summer data.
  • (*)Try to overfit.

Plant a forest

Decision trees are models that have low bias but high variance. In 2001, Leo Breiman proposed a new family of models, called random forests, which average scores from multiple decision trees trained on bootstrap samples of the data. This reduces variance at the expense of a small increase in bias, and quite often leads to a better model. The whole algorithm is a bit more complex but also very fascinating. You can read about it at https://tinyurl.com/RF2001. Nowadays a very popular, in a sense complementary, technique for improving models is boosting, which reduces bias at the expense of variance.
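Before delegating this to mlr3, note that the averaging idea itself fits in a few lines. A minimal bagging sketch built on ctree (the helper bag_ctree is ours; it only averages trees on bootstrap samples, while a real random forest additionally subsamples candidate variables at each split):

```r
library("partykit")

# bagging sketch: fit B trees on bootstrap samples, return a scoring function
bag_ctree <- function(formula, data, B = 25, maxdepth = 3) {
  trees <- lapply(seq_len(B), function(b) {
    boot <- data[sample(nrow(data), replace = TRUE), ]   # bootstrap sample
    ctree(formula, boot, control = ctree_control(maxdepth = maxdepth))
  })
  function(newdata) {  # average the second-class probability over all trees
    rowMeans(sapply(trees, function(t) predict(t, newdata, type = "prob")[, 2]))
  }
}
```

The returned closure plays the same role as the cdc_risk function from Day 1: it maps a data frame of patients to a vector of scores.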

We will train a random forest with the mlr3 library. The first step is to define the prediction task. More info

library("mlr3")

covid_task <- TaskClassif$new(id = "covid_spring",
                             backend = covid_spring,
                             target = "Death",
                             positive = "Yes")
covid_task
## <TaskClassif:covid_spring> (10000 x 8)
## * Target: Death
## * Properties: twoclass
## * Features (7):
##   - fct (6): Cancer, Cardiovascular.Diseases, Diabetes, Gender,
##     Kidney.Diseases, Neurological.Diseases
##   - int (1): Age

Now we need to define the family of models in which we want to look for a solution. A random forest is specified by the "classif.ranger" key. To find the best model in this family we use train().

More info

library("mlr3learners")
library("ranger")

covid_ranger <- lrn("classif.ranger", predict_type = "prob",
                num.trees = 25)
covid_ranger
## <LearnerClassifRanger:classif.ranger>
## * Model: -
## * Parameters: num.trees=25
## * Packages: ranger
## * Predict Type: prob
## * Feature types: logical, integer, numeric, character, factor, ordered
## * Properties: importance, multiclass, oob_error, twoclass, weights
covid_ranger$train(covid_task)

Test your model

A trained model can be turned into an explainer. Simpler functions could be used to calculate the performance of this model, but using explainers has an advantage that will be seen in all its beauty in just two pages.

model_ranger <-  explain(covid_ranger,
                           predict_function = function(m,x)
                                predict(m, x, predict_type = "prob")[,1],
                           data = covid_summer,
                           y = covid_summer$Death == "Yes",
                           type = "classification",
                           label = "Ranger",
                           verbose = FALSE)

mp_ranger <- model_performance(model_ranger)
mp_ranger
## Measures for:  classification
## recall     : 0.111588 
## precision  : 0.4482759 
## f1         : 0.1786942 
## accuracy   : 0.9761 
## auc        : 0.9425837
## 
## Residuals:
##           0%          10%          20%          30%          40%          50% 
## -0.633536505 -0.129990764 -0.018303930 -0.011174022 -0.011142358 -0.011065548 
##          60%          70%          80%          90%         100% 
## -0.009678616 -0.007173015 -0.007121335 -0.007089671  0.992878665
plot(mp_ranger, geom = "roc")

plot(mp_ranger, mp_tree, mp_cdc, geom = "roc")

Your turn

  • Check the AUC for Ranger model on the covid_spring data.
  • Plot ROC for both covid_spring and covid_summer data.
  • (extra) In the DALEX package you will find titanic_imputed dataset. Build a tree based model for the Titanic dataset. How good is your model?

Part 4: Hyperparameter optimization + Wrap-up

Hyperparameter Optimisation

Machine Learning algorithms typically have many hyperparameters that determine how the model is to be trained. For models with high variance, the selection of such hyperparameters has a strong impact on the quality of the final solution. The mlr3tuning package contains procedures to automate the process of finding good hyperparameters.

See: https://mlr3book.mlr-org.com/tuning.html.

To use it, you must specify the hyperparameter space to search. Not all hyperparameters are worth optimizing. In the example below, we focus on four hyperparameters of the random forest algorithm.

Automated Hyperparameter Optimisation

For automatic hyperparameter search, it is necessary to specify a few more elements: (1) a stopping criterion (below, a limit of 10 evaluations), (2) a search strategy over the parameter space (below, random search), and (3) a way to evaluate the performance of the proposed models (below, the AUC determined by 5-fold cross-validation).

Define the search space

In order to search for optimal parameters automatically, it is first necessary to specify the space of possible hyperparameters.

More info

library("mlr3tuning")
library("paradox")

search_space = ps(
  num.trees = p_int(lower = 50, upper = 500),
  max.depth = p_int(lower = 1, upper = 10),
  minprop = p_dbl(lower = 0.01, upper = 0.1),
  splitrule = p_fct(levels = c("gini", "extratrees"))
)
search_space
## <ParamSet>
##           id    class lower upper          levels        default value
## 1: num.trees ParamInt 50.00 500.0                 <NoDefault[3]>      
## 2: max.depth ParamInt  1.00  10.0                 <NoDefault[3]>      
## 3:   minprop ParamDbl  0.01   0.1                 <NoDefault[3]>      
## 4: splitrule ParamFct    NA    NA gini,extratrees <NoDefault[3]>

Set-up the tuner

Popular search strategies are random_search and grid_search. Termination is set to a specific number of evaluations. Internal testing is based on 5-fold CV.

More info

tuned_ranger = AutoTuner$new(
  learner    = covid_ranger,
  resampling = rsmp("cv", folds = 5),
  measure    = msr("classif.auc"),
  search_space = search_space,
  terminator = trm("evals", n_evals = 10),
  tuner    = tnr("random_search")
)
tuned_ranger
## <AutoTuner:classif.ranger.tuned>
## * Model: -
## * Parameters: list()
## * Packages: ranger
## * Predict Type: prob
## * Feature types: logical, integer, numeric, character, factor, ordered
## * Properties: importance, multiclass, oob_error, twoclass, weights

Tune

tuned_ranger$train(covid_task)
tuned_ranger$tuning_result
##    num.trees max.depth    minprop splitrule learner_param_vals  x_domain
## 1:       264         9 0.06907318      gini          <list[4]> <list[4]>
##    classif.auc
## 1:   0.9272979
tuned_ranger$predict_newdata(newdata = covid_spring)$prob[1:4,]
##              Yes        No
## [1,] 0.009092137 0.9909079
## [2,] 0.009700991 0.9902990
## [3,] 0.009164498 0.9908355
## [4,] 0.009164498 0.9908355

Test your model

model_tuned <-  explain(tuned_ranger,
                           predict_function = function(m,x)
                               m$predict_newdata(newdata = x)$prob[,1],
                           data = covid_summer,
                           y = covid_summer$Death == "Yes",
                           type = "classification",
                           label = "AutoTune",
                           verbose = FALSE)

mp_tuned <- model_performance(model_tuned)
mp_tuned
## Measures for:  classification
## recall     : 0.05150215 
## precision  : 0.4285714 
## f1         : 0.09195402 
## accuracy   : 0.9763 
## auc        : 0.9447171
## 
## Residuals:
##           0%          10%          20%          30%          40%          50% 
## -0.592588661 -0.127894604 -0.022977442 -0.009700991 -0.009164498 -0.008951611 
##          60%          70%          80%          90%         100% 
## -0.008855070 -0.007352912 -0.007157637 -0.007001973  0.992859130
plot(mp_tuned, geom = "roc")

plot(mp_ranger, mp_tree, mp_cdc, mp_tuned, geom = "roc")

Sum up

do.call(rbind, 
        list(cdc   = mp_cdc$measures,
            tree   = mp_tree$measures,
            ranger = mp_ranger$measures,
            tuned  = mp_tuned$measures))
##        recall     precision f1         accuracy auc      
## cdc    0.2188841  0.2602041 0.2377622  0.9673   0.904654 
## tree   0.8626609  0.1492205 0.2544304  0.8822   0.9136169
## ranger 0.111588   0.4482759 0.1786942  0.9761   0.9425837
## tuned  0.05150215 0.4285714 0.09195402 0.9763   0.9447171

Your turn

  • Check the AUC for AutoTune model on the covid_spring data.
  • Plot ROC for both covid_spring and covid_summer data.
  • (extra) In the DALEX package you will find titanic_imputed dataset. Optimize a tree based model for the Titanic dataset. How good is your model?

Day 2: Model exploration

We will devote the second day entirely to talking about methods for model exploration.

More info

DALEX pyramid

Part 1: Model level analysis - variable importance

Some models have built-in methods for the assessment of variable importance. For linear models one can use standardized model coefficients or p-values. For random forests one can use the out-of-bag classification error. For tree boosting models one can use gain statistics. The problem with such measures is that not all models have built-in variable importance statistics (e.g. neural networks) and that scores from different models cannot be directly compared (how would you compare gains with p-values?).

This is why we need a model-agnostic approach that is comparable between different models. The procedure described below is universal and does not depend on the model structure.

The procedure is based on variable perturbations in the validation data. If a variable is important in a model, then after its permutation the model predictions should be less accurate.

The permutation-based variable importance of variable \(i\) is the difference between the model loss measured on data with the permuted variable \(i\) and the loss on the original data

\[ VI(i) = L(f, X^{perm(i)}, y) - L(f, X, y) \]

where \(L(f, X, y)\) is the value of the loss function for the original data \(X\), true labels \(y\) and model \(f\), while \(X^{perm(i)}\) is the dataset \(X\) with the \(i\)-th variable permuted.

Which performance measure should you choose? It’s up to you. In the DALEX library, by default, RMSE is used for regression and 1-AUC for classification problems. But you can change this by specifying the loss_function argument.
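The whole procedure is simple enough to sketch by hand. Below is an illustrative toy regression model with a known structure, scored with MSE loss (everything here is made up for illustration; in practice DALEX::model_parts does this for you, with boxplots over repeated permutations):

```r
set.seed(1)
# toy data: the target depends strongly on x1 and only weakly on x2
n <- 1000
X <- data.frame(x1 = rnorm(n), x2 = rnorm(n))
y <- 2 * X$x1 + 0.1 * X$x2 + rnorm(n, sd = 0.1)
f <- function(X) 2 * X$x1 + 0.1 * X$x2          # the "model"
loss_mse <- function(f, X, y) mean((f(X) - y)^2)

vi <- sapply(names(X), function(i) {
  X_perm <- X
  X_perm[[i]] <- sample(X_perm[[i]])            # permute variable i
  loss_mse(f, X_perm, y) - loss_mse(f, X, y)    # VI(i) from the formula above
})
round(vi, 3)   # x1 dominates, as expected from how y was generated
```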

More info

mpart_ranger <- model_parts(model_ranger)
mpart_ranger
##                   variable mean_dropout_loss  label
## 1             _full_model_        0.05522936 Ranger
## 2                   Gender        0.05449878 Ranger
## 3                    Death        0.05522936 Ranger
## 4    Neurological.Diseases        0.05573063 Ranger
## 5                 Diabetes        0.05871462 Ranger
## 6          Kidney.Diseases        0.05922719 Ranger
## 7                   Cancer        0.06361618 Ranger
## 8  Cardiovascular.Diseases        0.08066089 Ranger
## 9                      Age        0.19733230 Ranger
## 10              _baseline_        0.51057115 Ranger
plot(mpart_ranger, show_boxplots = FALSE, bar_width=4) +
  DALEX:::theme_ema_vertical() + 
  theme( axis.text = element_text(color = "black", size = 12, hjust = 0)) +
  ggtitle("Variable importance","")

mpart_ranger <- model_parts(model_ranger, type = "difference")
mpart_ranger
##                   variable mean_dropout_loss  label
## 1             _full_model_      0.0000000000 Ranger
## 2                 Diabetes     -0.0003780217 Ranger
## 3                    Death      0.0000000000 Ranger
## 4          Kidney.Diseases      0.0004828034 Ranger
## 5    Neurological.Diseases      0.0009009033 Ranger
## 6                   Gender      0.0073552202 Ranger
## 7                   Cancer      0.0083969538 Ranger
## 8  Cardiovascular.Diseases      0.0137486420 Ranger
## 9                      Age      0.1609372266 Ranger
## 10              _baseline_      0.4538053702 Ranger
plot(mpart_ranger, show_boxplots = FALSE, bar_width=4) +
  DALEX:::theme_ema_vertical() + 
  theme( axis.text = element_text(color = "black", size = 12, hjust = 0)) +
  ggtitle("Variable importance","")

mpart_cdc <- model_parts(model_cdc)
mpart_tree <- model_parts(model_tree)
mpart_tuned <- model_parts(model_tuned)

plot(mpart_cdc, mpart_tree, mpart_ranger, mpart_tuned, show_boxplots = FALSE, bar_width=4) +
  DALEX:::theme_ema_vertical() + 
  theme( axis.text = element_text(color = "black", size = 12, hjust = 0)) +
  ggtitle("Variable importance","")

Your turn

  • Compare results for covid_summer with results on the covid_spring data.
  • (extra) In the DALEX package you will find titanic_imputed dataset. Train a ranger model and calculate variable importance.

Part 2: Model level analysis - variable profile

Partial dependence profiles are averages from CP profiles for all (or a large enough number) observations.

The model_profile() function calculates PD profiles for a specified model and variables (all by default).
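The averaging itself can be written in a few lines. A minimal sketch of the PD computation (the helper pd_profile is ours; model_profile does this, plus sensible grid selection, grouping and clustering):

```r
# PD profile: for each grid value z, fix the variable at z for every
# observation (ceteris paribus) and average the model predictions
pd_profile <- function(predict_fun, X, variable, grid) {
  sapply(grid, function(z) {
    X_z <- X
    X_z[[variable]] <- z     # one CP profile point per observation
    mean(predict_fun(X_z))   # average of CP profiles = PD value at z
  })
}

# e.g. with the cdc_risk function from Day 1:
# pd_profile(cdc_risk, covid_summer, "Age", grid = 0:100)
```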

More info

mprof_cdc <- model_profile(model_cdc, "Age")
plot(mprof_cdc)

mgroup_ranger <- model_profile(model_ranger, variable_splits = list(Age = 0:100))
plot(mgroup_ranger)+
  DALEX:::theme_ema() + 
  theme( axis.text = element_text(color = "black", size = 12, hjust = 0)) +
  ggtitle("Variable profile","")

Grouped partial dependence profiles

By default, the average is calculated over all observations. But with the groups= argument one can specify a factor variable; CP profiles will then be averaged within its levels.

mgroup_ranger <- model_profile(model_ranger, variable_splits = list(Age = 0:100), groups = "Cardiovascular.Diseases")
plot(mgroup_ranger)+
  DALEX:::theme_ema() + 
  theme( axis.text = element_text(color = "black", size = 12, hjust = 0)) +
  ggtitle("PDP variable profile","") + ylab("") + theme(legend.position = "top")

mgroup_ranger <- model_profile(model_ranger, "Age", k = 3, center = TRUE)
plot(mgroup_ranger)+
  DALEX:::theme_ema() + 
  theme( axis.text = element_text(color = "black", size = 12, hjust = 0)) +
  ggtitle("Variable profile","")

mprof_cdc <- model_profile(model_cdc, variable_splits = list(Age=0:100))
mprof_tree <- model_profile(model_tree, variable_splits = list(Age=0:100))
mprof_ranger <- model_profile(model_ranger, variable_splits = list(Age=0:100))
mprof_tuned <- model_profile(model_tuned, variable_splits = list(Age=0:100))

Profiles can be then drawn with the plot() function.

plot(mprof_tuned, mprof_cdc, mprof_tree, mprof_ranger)

If the model is additive, all CP profiles are parallel. But if the model has interactions, CP profiles may have different shapes for different observations. Setting the k argument clusters the CP profiles into k segments and calculates the average within each segment.

PD profiles do not take into account the correlation structure between the variables. For correlated variables, the ceteris-paribus assumption may not make sense. The model_profile() function can also calculate other types of aggregates, such as marginal profiles and accumulated local profiles. To do this, set the type= argument to "conditional" or "accumulated".
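As a sketch, assuming the model_ranger explainer from the earlier parts is still available, switching to accumulated local profiles only requires changing the type argument:

```r
# Accumulated local profiles are more robust to correlated variables
# than PD profiles; only the type argument changes
mprof_ale <- model_profile(model_ranger, variable_splits = list(Age = 0:100),
                           type = "accumulated")
plot(mprof_ale)
```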

Your turn

  • Compare results for covid_summer with results on the covid_spring data.
  • (extra) In the DALEX package you will find the titanic_imputed dataset. Train a ranger model and plot variable profiles.

Part 3: Instance level analysis - variable attributions

Once we calculate the model prediction, the question often arises which variables had the greatest impact on it.

For linear models it is easy to assess the impact of individual variables, because there is a single coefficient for each variable.
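For instance, in a logistic regression the contribution of a variable to a prediction (on the log-odds scale) is simply its coefficient times its value. A minimal sketch on the illustrative titanic_imputed data from DALEX (this model is not part of the Covid example):

```r
library("DALEX")

# a simple logistic regression on the illustrative titanic_imputed data
m <- glm(survived ~ age + fare, data = titanic_imputed, family = "binomial")
coef(m)

# attribution of age for the first passenger (on the log-odds scale):
# coefficient * value
coef(m)["age"] * titanic_imputed$age[1]
```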


john <- data.frame(Gender = factor("Male", levels = c("Male", "Female")),
                   Age = 76,
                   Cardiovascular.Diseases = factor("Yes", levels = c("Yes", "No")), 
                   Diabetes = factor("No", levels = c("Yes", "No")), 
                   Neurological.Diseases = factor("No", levels = c("Yes", "No")), 
                   Kidney.Diseases = factor("No", levels = c("Yes", "No")), 
                   Cancer = factor("No", levels = c("Yes", "No")))
john
##   Gender Age Cardiovascular.Diseases Diabetes Neurological.Diseases
## 1   Male  76                     Yes       No                    No
##   Kidney.Diseases Cancer
## 1              No     No

It turns out that such attributions can be calculated for any predictive model. The most popular model-agnostic method is Shapley values. They can be calculated with the predict_parts() function.


ppart_cdc <- predict_parts(model_cdc, john, type = "shap")
plot(ppart_cdc)

ppart_tree <- predict_parts(model_tree, john, type = "shap")
plot(ppart_tree)

ppart_ranger <- predict_parts(model_ranger, john, type = "shap")
plot(ppart_ranger)

ppart_tuned <- predict_parts(model_tuned, john, type = "shap")
plot(ppart_tuned)

The show_boxplots argument allows you to show boxplots that summarize the stability of the estimated attributions.

Other possible values of the type argument are "oscillations", "shap", "break_down", and "break_down_interactions".

With the order argument one can force a specific sequence of variables.
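A sketch, assuming the model_cdc explainer from the earlier parts and the john observation defined above:

```r
# break-down attributions with a user-specified order of variables
ppart_bd <- predict_parts(model_cdc, john, type = "break_down",
                          order = c("Age", "Cardiovascular.Diseases", "Gender",
                                    "Diabetes", "Neurological.Diseases",
                                    "Kidney.Diseases", "Cancer"))
plot(ppart_bd)
```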

By default, functions such as model_parts(), predict_parts(), and model_profile() do not calculate statistics on the entire data set, but on a random sample of cases, and the entire procedure is repeated B times to estimate the error bars.

ppart_cdc <- predict_parts(model_cdc, john)
plot(ppart_cdc)

Your turn

  • Compare results for covid_summer with results on the covid_spring data.
  • (extra) In the DALEX package you will find the titanic_imputed dataset. Train a ranger model and calculate local variable attributions.

Part 4: Instance level analysis - variable profile + Wrap-up

Profile for a single prediction

Ceteris paribus is a Latin phrase for "other things being equal".

Ceteris-paribus (CP) profiles show how the model response would change for a selected observation if one coordinate of that observation were changed while all the other coordinates were kept unchanged.

The predict_profile() function calculates CP profiles for a selected observation, model, and vector of variables (all continuous variables by default).


mprof_cdc <- predict_profile(model_cdc, john, "Age")
plot(mprof_cdc)

CP profiles can be visualized with the generic plot() function.

For technical reasons, quantitative and qualitative variables cannot be shown in a single chart. So if you want to show profiles for qualitative variables, you need to plot them separately.

mprof_cdc <- predict_profile(model_cdc, john, variable_splits = list(Age = 0:100))
mprof_tree <- predict_profile(model_tree, john, variable_splits = list(Age = 0:100))
mprof_ranger <- predict_profile(model_ranger, john, variable_splits = list(Age = 0:100))
mprof_tuned <- predict_profile(model_tuned, john, variable_splits = list(Age = 0:100))

plot(mprof_tuned, mprof_cdc, mprof_tree, mprof_ranger)

Local importance of variables can be measured as oscillations of CP profiles. The greater the variability of the CP profile, the more important the variable. Set type = "oscillations" in the predict_parts() function.
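A sketch, again assuming the model_ranger explainer and the john observation from above:

```r
# oscillations-based local importance of variables for john
posc_ranger <- predict_parts(model_ranger, john, type = "oscillations")
plot(posc_ranger)
```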

Your turn

  • Compare results for covid_summer with results on the covid_spring data.
  • (extra) In the DALEX package you will find the titanic_imputed dataset. Train a ranger model and calculate local profiles.

Extras

Play with your model!


library("modelStudio")

ms <- modelStudio(model_ranger)
ms

Session info

devtools::session_info()
## ─ Session info ───────────────────────────────────────────────────────────────
##  setting  value                       
##  version  R version 4.0.2 (2020-06-22)
##  os       macOS Catalina 10.15.7      
##  system   x86_64, darwin17.0          
##  ui       X11                         
##  language (EN)                        
##  collate  en_US.UTF-8                 
##  ctype    en_US.UTF-8                 
##  tz       Europe/Warsaw               
##  date     2021-04-17                  
## 
## ─ Packages ───────────────────────────────────────────────────────────────────
##  package        * version date       lib source        
##  assertthat       0.2.1   2019-03-21 [1] CRAN (R 4.0.0)
##  backports        1.1.10  2020-09-15 [1] CRAN (R 4.0.2)
##  bbotk            0.3.0   2021-01-24 [1] CRAN (R 4.0.2)
##  callr            3.5.1   2020-10-13 [1] CRAN (R 4.0.2)
##  checkmate        2.0.0   2020-02-06 [1] CRAN (R 4.0.0)
##  class            7.3-17  2020-04-26 [1] CRAN (R 4.0.2)
##  cli              2.3.0   2021-01-31 [1] CRAN (R 4.0.2)
##  codetools        0.2-16  2018-12-24 [1] CRAN (R 4.0.2)
##  colorspace       2.0-0   2020-11-11 [1] CRAN (R 4.0.2)
##  crayon           1.4.1   2021-02-08 [1] CRAN (R 4.0.2)
##  DALEX          * 2.2.0   2021-03-20 [1] CRAN (R 4.0.2)
##  data.table       1.14.0  2021-02-21 [1] CRAN (R 4.0.2)
##  DBI              1.1.1   2021-01-15 [1] CRAN (R 4.0.2)
##  desc             1.2.0   2018-05-01 [1] CRAN (R 4.0.0)
##  devtools         2.3.2   2020-09-18 [1] CRAN (R 4.0.2)
##  digest           0.6.27  2020-10-24 [1] CRAN (R 4.0.2)
##  dplyr            1.0.2   2020-08-18 [1] CRAN (R 4.0.2)
##  e1071            1.7-4   2020-10-14 [1] CRAN (R 4.0.2)
##  ellipsis         0.3.1   2020-05-15 [1] CRAN (R 4.0.0)
##  evaluate         0.14    2019-05-28 [1] CRAN (R 4.0.0)
##  farver           2.0.3   2020-01-16 [1] CRAN (R 4.0.0)
##  forcats          0.5.0   2020-03-01 [1] CRAN (R 4.0.0)
##  Formula          1.2-4   2020-10-16 [1] CRAN (R 4.0.2)
##  fs               1.5.0   2020-07-31 [1] CRAN (R 4.0.2)
##  future           1.19.1  2020-09-22 [1] CRAN (R 4.0.2)
##  future.apply     1.6.0   2020-07-01 [1] CRAN (R 4.0.0)
##  generics         0.1.0   2020-10-31 [1] CRAN (R 4.0.2)
##  ggplot2        * 3.3.3   2020-12-30 [1] CRAN (R 4.0.2)
##  globals          0.13.1  2020-10-11 [1] CRAN (R 4.0.2)
##  glue             1.4.2   2020-08-27 [1] CRAN (R 4.0.2)
##  gtable           0.3.0   2019-03-25 [1] CRAN (R 4.0.0)
##  haven            2.3.1   2020-06-01 [1] CRAN (R 4.0.0)
##  hms              0.5.3   2020-01-08 [1] CRAN (R 4.0.0)
##  htmltools        0.5.0   2020-06-16 [1] CRAN (R 4.0.1)
##  iBreakDown       1.3.1   2020-07-29 [1] CRAN (R 4.0.2)
##  ingredients      2.0.1   2021-02-05 [1] CRAN (R 4.0.2)
##  inum             1.0-1   2019-04-25 [1] CRAN (R 4.0.0)
##  knitr            1.30    2020-09-22 [1] CRAN (R 4.0.2)
##  labeling         0.4.2   2020-10-20 [1] CRAN (R 4.0.2)
##  labelled         2.7.0   2020-09-21 [1] CRAN (R 4.0.2)
##  lattice          0.20-41 2020-04-02 [1] CRAN (R 4.0.2)
##  lgr              0.4.1   2020-10-20 [1] CRAN (R 4.0.2)
##  libcoin        * 1.0-6   2020-08-14 [1] CRAN (R 4.0.2)
##  lifecycle        1.0.0   2021-02-15 [1] CRAN (R 4.0.2)
##  listenv          0.8.0   2019-12-05 [1] CRAN (R 4.0.0)
##  magrittr         2.0.1   2020-11-17 [1] CRAN (R 4.0.2)
##  MASS             7.3-53  2020-09-09 [1] CRAN (R 4.0.2)
##  Matrix           1.2-18  2019-11-27 [1] CRAN (R 4.0.2)
##  memoise          1.1.0   2017-04-21 [1] CRAN (R 4.0.0)
##  mitools          2.4     2019-04-26 [1] CRAN (R 4.0.0)
##  mlr3           * 0.11.0  2021-03-05 [1] CRAN (R 4.0.2)
##  mlr3learners   * 0.4.1   2020-10-07 [1] CRAN (R 4.0.2)
##  mlr3measures     0.3.0   2020-10-05 [1] CRAN (R 4.0.2)
##  mlr3misc         0.7.0   2021-01-05 [1] CRAN (R 4.0.2)
##  mlr3tuning     * 0.7.0   2021-02-11 [1] CRAN (R 4.0.2)
##  munsell          0.5.0   2018-06-12 [1] CRAN (R 4.0.0)
##  mvtnorm        * 1.1-1   2020-06-09 [1] CRAN (R 4.0.0)
##  palmerpenguins   0.1.0   2020-07-23 [1] CRAN (R 4.0.2)
##  paradox        * 0.7.0   2021-01-23 [1] CRAN (R 4.0.2)
##  parallelly       1.23.0  2021-01-04 [1] CRAN (R 4.0.2)
##  partykit       * 1.2-10  2020-10-12 [1] CRAN (R 4.0.2)
##  pheatmap       * 1.0.12  2019-01-04 [1] CRAN (R 4.0.2)
##  pillar           1.4.7   2020-11-20 [1] CRAN (R 4.0.2)
##  pkgbuild         1.2.0   2020-12-15 [1] CRAN (R 4.0.2)
##  pkgconfig        2.0.3   2019-09-22 [1] CRAN (R 4.0.0)
##  pkgload          1.1.0   2020-05-29 [1] CRAN (R 4.0.0)
##  prettyunits      1.1.1   2020-01-24 [1] CRAN (R 4.0.0)
##  processx         3.4.5   2020-11-30 [1] CRAN (R 4.0.2)
##  ps               1.5.0   2020-12-05 [1] CRAN (R 4.0.2)
##  purrr            0.3.4   2020-04-17 [1] CRAN (R 4.0.0)
##  R6               2.5.0   2020-10-28 [1] CRAN (R 4.0.2)
##  ranger         * 0.12.1  2020-01-10 [1] CRAN (R 4.0.0)
##  RColorBrewer     1.1-2   2014-12-07 [1] CRAN (R 4.0.2)
##  Rcpp             1.0.6   2021-01-15 [1] CRAN (R 4.0.2)
##  remotes          2.2.0   2020-07-21 [1] CRAN (R 4.0.2)
##  rlang            0.4.10  2020-12-30 [1] CRAN (R 4.0.2)
##  rmarkdown        2.4     2020-09-30 [1] CRAN (R 4.0.2)
##  rpart            4.1-15  2019-04-12 [1] CRAN (R 4.0.2)
##  rprojroot        2.0.2   2020-11-15 [1] CRAN (R 4.0.2)
##  scales           1.1.1   2020-05-11 [1] CRAN (R 4.0.0)
##  sessioninfo      1.1.1   2018-11-05 [1] CRAN (R 4.0.0)
##  stringi          1.5.3   2020-09-09 [1] CRAN (R 4.0.2)
##  stringr          1.4.0   2019-02-10 [1] CRAN (R 4.0.0)
##  survey           4.0     2020-04-03 [1] CRAN (R 4.0.0)
##  survival         3.2-7   2020-09-28 [1] CRAN (R 4.0.2)
##  tableone       * 0.12.0  2020-07-26 [1] CRAN (R 4.0.2)
##  testthat         3.0.2   2021-02-14 [1] CRAN (R 4.0.2)
##  tibble           3.0.6   2021-01-29 [1] CRAN (R 4.0.2)
##  tidyselect       1.1.0   2020-05-11 [1] CRAN (R 4.0.0)
##  usethis          1.6.3   2020-09-17 [1] CRAN (R 4.0.2)
##  uuid             0.1-4   2020-02-26 [1] CRAN (R 4.0.0)
##  vctrs            0.3.6   2020-12-17 [1] CRAN (R 4.0.2)
##  withr            2.4.1   2021-01-26 [1] CRAN (R 4.0.2)
##  xfun             0.18    2020-09-29 [1] CRAN (R 4.0.2)
##  yaml             2.2.1   2020-02-01 [1] CRAN (R 4.0.0)
##  zoo              1.8-8   2020-05-02 [1] CRAN (R 4.0.0)
## 
## [1] /Library/Frameworks/R.framework/Versions/4.0/Resources/library